#' ---
#' title: "Campaign Communication and Legislative Leadership (PSRM)"
#' subtitle: "05_performance_all_classifiers.R"
#' author: "Authors: Stefan Mueller and Naofumi Fujimura"
#' date: "Note: Code compiled successfully on `r format(Sys.time(), '%d %B %Y')`"
#' ---

# load packages
library(dplyr)               # CRAN v1.1.2
library(ggplot2)             # CRAN v3.4.2
library(ggrepel)             # CRAN v0.9.3
library(mltest)              # CRAN v1.0.1
library(quanteda)            # CRAN v3.3.1
library(quanteda.textmodels) # CRAN v0.9.6
library(quanteda.textstats)  # CRAN v0.96.3
library(xtable)              # CRAN v1.8-4
library(stringr)             # CRAN v1.5.0


# If the code does not run, one or more packages may have been 
# updated, which may result in errors or conflicts. You can solve this issue
# by installing the package version listed above or by using the 
# groundhog package:
# after installing groundhog using install.packages("groundhog")
# change library(name_of_package) to
# groundhog::groundhog.library(name_of_package, date = "2024-01-31")
# Instead of adjusting the library() call for each package, 
# you can adjust them all at once using
# the following syntax:
# groundhog.library("library('pkgA')
#                   library('pkgB')
#                   library('pkgC')", date = "2024-01-31")
# More details are available at: https://groundhogr.com/using/

# Versions of the quanteda packages used for the published analysis.
# Use these versions if you have issues reproducing the analysis.
expected_versions <- c(quanteda = "3.3.1",
                       quanteda.textmodels = "0.9.6",
                       quanteda.textstats = "0.96.3")

# installed versions of the same packages
installed_versions <- vapply(names(expected_versions),
                             function(pkg) as.character(packageVersion(pkg)),
                             character(1))

# print installed and expected versions side by side
data.frame(package = names(expected_versions),
           installed = unname(installed_versions),
           expected = unname(expected_versions))

# warn (rather than stop) so the script still runs with other versions,
# but the reader is told that results may differ
version_mismatch <- package_version(unname(installed_versions)) !=
    package_version(unname(expected_versions))
if (any(version_mismatch)) {
    warning("Installed versions differ from those used in the paper for: ",
            paste(names(expected_versions)[version_mismatch], collapse = ", "),
            ". Results may not reproduce exactly.",
            call. = FALSE)
}

# print output of sessionInfo()
sessionInfo()

# load custom ggplot2 scheme
source("function_theme_base.R")

# load predicted classes based on BERT classification
dat_predicted <- read.csv("data_test_predicted_bert.csv")

# should be 485 statements (15 statements about Judicial Affairs were
# removed; see paper for details)
nrow(dat_predicted)
if (nrow(dat_predicted) != 485L) {
    warning("Expected 485 statements in the BERT test set, found ",
            nrow(dat_predicted), ".", call. = FALSE)
}

# BERT returns numeric class labels; the names of this vector are those
# labels, the values the corresponding policy areas
policy_area_labels <- c(
    "0" = "Agriculture, Forestry, and Fisheries",
    "1" = "Committees on Cabinet",
    "2" = "Economy, Trade and Industry",
    "3" = "Education, Culture, Sports, Science, and Technology",
    "4" = "Environment",
    "5" = "Financial Affairs",
    "6" = "Foreign Affairs",
    "7" = "Health, Labor, and Welfare",
    "8" = "Internal Affairs and Communications",
    "9" = "Land, Infrastructure, Transport, and Tourism",
    "10" = "No policy area",
    "11" = "Security"
)

# change numbers to policy areas and harmonise the spelling of the
# human-coded categories ("Labour" -> "Labor")
dat_predicted <- dat_predicted |>
    mutate(policy_area_pred_clean = unname(policy_area_labels[as.character(label)]),
           policy_area = str_replace_all(policy_area, "Labour", "Labor"))

# every non-missing BERT label must map to a policy area
unmatched_labels <- !is.na(dat_predicted$label) &
    is.na(dat_predicted$policy_area_pred_clean)
if (any(unmatched_labels)) {
    stop("Unknown BERT class label(s): ",
         paste(unique(dat_predicted$label[unmatched_labels]), collapse = ", "),
         call. = FALSE)
}


# Figure 01 ----

# check count of statements in test set based on human codings
# ("Labour" was already recoded to "Labor" when the predictions were
# loaded, so no further recoding is needed here)
dat_pred_sum_humans <- dat_predicted |>
    count(policy_area, name = "n_humans")

# do the same for the BERT predictions, using the same column name for
# the category so that both tables can be joined
dat_pred_sum_bert <- dat_predicted |>
    count(policy_area = policy_area_pred_clean, name = "n_bert")

# Merge both datasets and exclude "No policy area".
# A full join keeps categories that occur in only one of the two codings
# and records the missing count as 0; a left join would silently drop
# categories BERT never predicts and leave NA for categories that humans
# never coded.
dat_joined_sum <- full_join(dat_pred_sum_bert,
                            dat_pred_sum_humans,
                            by = "policy_area") |>
    mutate(n_bert = coalesce(n_bert, 0L),
           n_humans = coalesce(n_humans, 0L)) |>
    filter(policy_area != "No policy area")


# correlation between BERT and human category frequencies; stored as
# cor_counts so that the name does not mask stats::cor()
cor_counts <- round(cor(dat_joined_sum$n_bert,
                        dat_joined_sum$n_humans,
                        use = "complete.obs"), 2)

cor_counts

# Axes span at least 0-65 (the original fixed range) but grow with the
# data: ggplot2 removes observations outside fixed scale limits before
# drawing points and fitting the regression line.
axis_max <- max(65, dat_joined_sum$n_bert, dat_joined_sum$n_humans,
                na.rm = TRUE)
axis_breaks <- seq(0, ceiling(axis_max / 10) * 10, by = 10)

# fixed seed so that the (randomised) label placement of
# geom_label_repel() is reproducible
set.seed(2335)
fig_01 <- ggplot(dat_joined_sum, aes(x = n_bert, y = n_humans)) +
    #geom_abline(linetype = "dashed", colour = "grey50") +
    geom_smooth(method = "lm", formula = y ~ x, alpha = 1,
                colour = "grey30", fill = "grey80") +
    ggrepel::geom_label_repel(aes(label = policy_area),
                              max.overlaps = 10) +
    geom_point(size = 4) +
    annotate("text", label = paste0("r=", cor_counts),
             x = 5, y = axis_max - 5,
             size = 5, colour = "grey50") +
    scale_x_continuous(limits = c(0, axis_max),
                       breaks = axis_breaks) +
    scale_y_continuous(limits = c(0, axis_max),
                       breaks = axis_breaks) +
    labs(x = "Frequency of Policy Area in Held-Out Test Set (BERT)",
         y = "Frequency of Policy Area in Held-Out Test Set (Human Coding)")
fig_01
ggsave("fig_01.pdf", plot = fig_01,
       width = 9, height = 6)




# calculate performance metrics for BERT classification

# Predicted and human-coded categories are converted to factors with the
# same levels (every category occurring in either vector), so that the
# confusion matrix is square and all classifiers are scored on the same
# set of categories. `u` is reused for SVM and Naive Bayes further down.
u <- union(dat_predicted$policy_area_pred_clean, dat_predicted$policy_area)

pred_bert <- factor(dat_predicted$policy_area_pred_clean, u)
true_bert <- factor(dat_predicted$policy_area, u)

# confusion matrix (rows: BERT, columns: human coding); not named `t`
# to avoid masking base::t()
confusion_bert <- table(BERT = pred_bert, Human = true_bert)

# get classifier performance (one row per category, categories as row names)
performance_bert <- ml_test(predicted = pred_bert,
                            true = true_bert,
                            output.as.table = TRUE)
performance_bert$classifier <- "BERT"
performance_bert$class <- rownames(performance_bert)

# average F1 score across categories (0.7348579 in the replication run)
mean(performance_bert$F1)



# compare with SVM and NB classification

# The bag-of-words classifiers are trained on the BERT training sentences
# plus the BERT evaluation sentences. This makes the comparison "easier"
# for them, since they get more labelled data than BERT's training split.
dat_train <- read.csv("data_sentences_train.csv", fileEncoding = "utf-8")
dat_eval <- read.csv("data_sentences_eval.csv", fileEncoding = "utf-8")
dat_train_all <- mutate(bind_rows(dat_train, dat_eval), train_test = "train")

# held-out test set
dat_test <- mutate(read.csv("data_sentences_test.csv", fileEncoding = "utf-8"),
                   train_test = "test")

# stack test and training sentences into one data frame (test rows first)
# and harmonise the spelling "Labour" -> "Labor"
dat_all <- bind_rows(dat_test, dat_train_all)
dat_all <- mutate(dat_all,
                  policy_area = str_replace_all(policy_area, "Labour", "Labor"))

# check that the frequencies are 2500 for train and 485 for test
table(dat_all$train_test)

# create text corpus
corp <- corpus(dat_all)

# tokenize based on recommendations at:
# https://tutorials.quanteda.io/multilingual/japanese/
toks <- tokens(corp, what = "word3")

# Find collocations among tokens that fully match the regex `pattern`
# (keeping those seen at least `min_count` times with z > `min_z`) and
# join each of them into a single token. Returns the modified tokens.
compound_collocations <- function(x, pattern, min_count = 5, min_z = 2) {
    colls <- x |>
        tokens_select(pattern, valuetype = "regex", padding = TRUE) |>
        textstat_collocations(min_count = min_count, tolower = FALSE) |>
        filter(z > min_z)
    tokens_compound(x, colls, concatenator = "", join = TRUE)
}

# compound kanji sequences first, then katakana sequences, then sequences
# mixing full-width digits, katakana, and kanji
toks_refi <- toks |>
    compound_collocations("^[一-龠]+$") |>
    compound_collocations("^[ァ-ヶー]+$") |>
    compound_collocations("^[０-９ァ-ヶー一-龠]+$")

# document-feature matrices, built separately for the training and the
# held-out test sentences (selected via the train_test docvar)
dfmat_train <- dfm(tokens_subset(toks_refi, train_test == "train"))
dfmat_test <- dfm(tokens_subset(toks_refi, train_test == "test"))

# Score predicted categories against the human coding with ml_test().
# `levels` fixes the factor levels of both vectors so that the per-category
# rows match those of the other classifiers. Returns a data frame with one
# row per category: the metrics, the classifier name (`classifier`), and
# the category (`class`).
score_classifier <- function(predicted, true, levels, classifier) {
    # factor() turns values outside `levels` into NA without notice
    unknown <- setdiff(unique(as.character(predicted)), c(levels, NA))
    if (length(unknown) > 0) {
        warning(classifier, " predicted categories not in the test set: ",
                paste(unknown, collapse = ", "), call. = FALSE)
    }
    perf <- ml_test(predicted = factor(predicted, levels),
                    true = factor(true, levels),
                    output.as.table = TRUE)
    perf$classifier <- classifier
    perf$class <- rownames(perf)
    perf
}

# train SVM
tmod_svm <- textmodel_svm(dfmat_train, dfmat_train$policy_area)

# predict held-out test data; the test dfm was built separately, so its
# features differ from the training dfm, hence force = TRUE (the resulting
# warnings are suppressed)
pred_svm <- suppressWarnings(predict(tmod_svm, dfmat_test, force = TRUE))

# get performance metrics and store as data frame
performance_svm <- score_classifier(pred_svm, dfmat_test$policy_area,
                                    levels = u, classifier = "SVM")

# train Naive Bayes
tmod_nb <- textmodel_nb(dfmat_train, dfmat_train$policy_area)

# predict held-out test data
pred_nb <- suppressWarnings(predict(tmod_nb, dfmat_test, force = TRUE))

# get performance metrics and store as data frame
performance_nb <- score_classifier(pred_nb, dfmat_test$policy_area,
                                   levels = u, classifier = "Naive Bayes")


# Table A03 ----
# compare performance metrics for the three classifiers

# stack the per-category metrics of all three classifiers into one table
dat_perf_all <- bind_rows(list(performance_nb, performance_bert, performance_svm))


# Keep the category plus four metrics of one classifier and label each
# metric column with the classifier's short name, e.g. "F1 (BERT)".
select_metrics <- function(perf, short_name) {
    metric_cols <- c(F1 = "F1",
                     Precision = "precision",
                     Recall = "recall",
                     `Bal. Acc` = "balanced.accuracy")
    names(metric_cols) <- paste0(names(metric_cols), " (", short_name, ")")
    dplyr::select(perf, class, all_of(metric_cols))
}

# BERT, with a quick check of the number of categories and F1 distribution
performance_bert_clean <- select_metrics(performance_bert, "BERT")
nrow(performance_bert_clean)
summary(performance_bert_clean$`F1 (BERT)`)

# SVM and Naive Bayes
performance_svm_clean <- select_metrics(performance_svm, "SVM")
performance_nb_clean <- select_metrics(performance_nb, "NB")


# Join the three classifiers into one data frame sorted by BERT's F1 score
# (highest first), with F1, precision, and recall columns grouped together.
# Each category must appear once per classifier; relationship =
# "one-to-one" makes a duplicated category an error instead of silently
# duplicating rows.
dat_predicted_sents <- performance_bert_clean |>
    left_join(performance_svm_clean, by = "class",
              relationship = "one-to-one") |>
    arrange(desc(`F1 (BERT)`)) |>
    left_join(performance_nb_clean, by = "class",
              relationship = "one-to-one") |>
    dplyr::select(Category = class,
                  starts_with("F1"),
                  starts_with("Precision"),
                  starts_with("Recall"),
                  everything())

# get number of human-coded sentences per category (ungrouped result)
dat_n <- dat_predicted |>
    count(Category = policy_area, name = "N Statements")

# join with data frame
dat_predicted_sents <- left_join(dat_predicted_sents, dat_n,
                                 by = "Category",
                                 relationship = "one-to-one")


# show the final table in the console
dat_predicted_sents

# Table A03: write as an HTML table (two decimals, no row names) for the paper
tab_a03 <- xtable(dat_predicted_sents, digits = 2)
print(tab_a03, type = "html", file = "tab_a03.html",
      include.rownames = FALSE)
